{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 参考资料\n",
    "[1] [GANs入门系列之（二）用GAN生成MNIST数据集之pytorch实现](https://blog.csdn.net/weixin_41278720/article/details/80861284)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 相关知识\n",
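    "- PyTorch 中的 `Variable`：自 PyTorch 0.4 起已与 `Tensor` 合并，`Variable(x)` 返回的仍是普通 `Tensor`，新代码可以直接使用张量。  \n",
    "- `nn.BCELoss()`：二元交叉熵损失，要求输入是 [0, 1] 区间内的概率（因此判别器最后一层使用 `Sigmoid`），并且输入与目标的形状必须一致。"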
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "from torchvision.utils import save_image\n",
    "from torch.autograd import Variable\n",
    "import os\n",
    " \n",
    "# Create the output directory for the saved image grids. exist_ok=True makes\n",
    "# this a no-op when the directory already exists and avoids the\n",
    "# check-then-create race of os.path.exists() followed by os.mkdir().\n",
    "os.makedirs('./img', exist_ok=True)\n",
    " \n",
    " \n",
    "def to_img(x):\n",
    "    \"\"\"Turn network-range values into a batch of single-channel 28x28 images.\n",
    "\n",
    "    The input is rescaled with ``0.5 * (x + 1)`` (mapping [-1, 1] onto [0, 1]),\n",
    "    clamped to [0, 1], and reshaped to ``(-1, 1, 28, 28)``.\n",
    "\n",
    "    Args:\n",
    "        x: tensor whose element count is a multiple of 28 * 28.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape ``(N, 1, 28, 28)`` with values in [0, 1].\n",
    "    \"\"\"\n",
    "    rescaled = (x + 1) * 0.5\n",
    "    return rescaled.clamp(0, 1).view(-1, 1, 28, 28)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters used by the data loader and the training loop.\n",
    "batch_size, num_epoch = 128, 100  # mini-batch size, number of passes over the dataset\n",
    "z_dimension = 100                 # length of the random noise vector fed to the generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocessing: convert to a tensor, then normalize the single grayscale\n",
    "# channel with mean 0.5 and std 0.5 (to_img() applies the inverse mapping).\n",
    "img_transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.5], std=[0.5]),\n",
    "])\n",
    "\n",
    "# MNIST training split, downloaded into ./data/ when it is not already there.\n",
    "mnist = datasets.MNIST(\n",
    "    root='./data/',\n",
    "    train=True,\n",
    "    transform=img_transform,\n",
    "    download=True,\n",
    ")\n",
    "\n",
    "# Shuffled mini-batch iterator over the dataset.\n",
    "dataloader = torch.utils.data.DataLoader(\n",
    "    mnist, batch_size=batch_size, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class discriminator(nn.Module):\n",
    "    \"\"\"Fully connected discriminator.\n",
    "\n",
    "    Maps a 784-dimensional input vector to a single value in (0, 1) through\n",
    "    two hidden layers of 256 units with LeakyReLU(0.2) and a final Sigmoid.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Layer order and the attribute name `dis` determine the state_dict keys.\n",
    "        layers = [\n",
    "            nn.Linear(784, 256),\n",
    "            nn.LeakyReLU(0.2),\n",
    "            nn.Linear(256, 256),\n",
    "            nn.LeakyReLU(0.2),\n",
    "            nn.Linear(256, 1),\n",
    "            nn.Sigmoid(),\n",
    "        ]\n",
    "        self.dis = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the discriminator output of shape ``(N, 1)`` for input ``(N, 784)``.\"\"\"\n",
    "        return self.dis(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class generator(nn.Module):\n",
    "    \"\"\"Fully connected generator.\n",
    "\n",
    "    Maps a 100-dimensional input vector to 784 values in [-1, 1] through two\n",
    "    hidden layers of 256 units with in-place ReLU and a final Tanh.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Layer order and the attribute name `gen` determine the state_dict keys.\n",
    "        layers = [\n",
    "            nn.Linear(100, 256),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Linear(256, 256),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Linear(256, 784),\n",
    "            nn.Tanh(),\n",
    "        ]\n",
    "        self.gen = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return generated vectors of shape ``(N, 784)`` for input ``(N, 100)``.\"\"\"\n",
    "        return self.gen(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Choose the device once so the models and every tensor created below live in\n",
    "# the same place; this also lets the notebook run on CPU-only machines.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "D = discriminator().to(device)\n",
    "G = generator().to(device)\n",
    "\n",
    "# Binary cross entropy loss and one Adam optimizer per network\n",
    "criterion = nn.BCELoss()\n",
    "d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)\n",
    "g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)\n",
    "\n",
    "# Start training\n",
    "for epoch in range(num_epoch):\n",
    "    for i, (img, _) in enumerate(dataloader):\n",
    "        num_img = img.size(0)\n",
    "\n",
    "        # ================= train discriminator\n",
    "        # Flatten each image to a vector before feeding it to D.\n",
    "        real_img = img.view(num_img, -1).to(device)\n",
    "        # Labels are shaped (N, 1) to match D's output; BCELoss requires the\n",
    "        # target and the input to have the same size.\n",
    "        real_label = torch.ones(num_img, 1, device=device)\n",
    "        fake_label = torch.zeros(num_img, 1, device=device)\n",
    "\n",
    "        # loss on real images\n",
    "        real_out = D(real_img)\n",
    "        d_loss_real = criterion(real_out, real_label)\n",
    "        real_scores = real_out  # closer to 1 means better\n",
    "\n",
    "        # loss on fake images; detach so the discriminator step does not\n",
    "        # backpropagate into the generator\n",
    "        z = torch.randn(num_img, z_dimension, device=device)\n",
    "        fake_img = G(z).detach()\n",
    "        fake_out = D(fake_img)\n",
    "        d_loss_fake = criterion(fake_out, fake_label)\n",
    "        fake_scores = fake_out  # closer to 0 means better\n",
    "\n",
    "        # backprop and optimize D\n",
    "        d_loss = d_loss_real + d_loss_fake\n",
    "        d_optimizer.zero_grad()\n",
    "        d_loss.backward()\n",
    "        d_optimizer.step()\n",
    "\n",
    "        # ================= train generator\n",
    "        # G is rewarded when D labels its samples as real.\n",
    "        z = torch.randn(num_img, z_dimension, device=device)\n",
    "        fake_img = G(z)\n",
    "        output = D(fake_img)\n",
    "        g_loss = criterion(output, real_label)\n",
    "\n",
    "        # backprop and optimize G\n",
    "        g_optimizer.zero_grad()\n",
    "        g_loss.backward()\n",
    "        g_optimizer.step()\n",
    "\n",
    "        if (i + 1) % 100 == 0:\n",
    "            # epoch + 1 keeps the log in step with the 1-based image filenames.\n",
    "            print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '\n",
    "                  'D real: {:.6f}, D fake: {:.6f}'.format(\n",
    "                      epoch + 1, num_epoch, d_loss.item(), g_loss.item(),\n",
    "                      real_scores.mean().item(), fake_scores.mean().item()))\n",
    "\n",
    "    # Save one grid of real images for reference after the first epoch.\n",
    "    if epoch == 0:\n",
    "        real_images = to_img(real_img.detach().cpu())\n",
    "        save_image(real_images, './img/real_images.png')\n",
    "\n",
    "    # Save the last generated batch of every epoch.\n",
    "    fake_images = to_img(fake_img.detach().cpu())\n",
    "    save_image(fake_images, './img/fake_images-{}.png'.format(epoch + 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the learned parameters (state_dict only) of both networks.\n",
    "for net, path in ((G, './generator.pth'), (D, './discriminator.pth')):\n",
    "    torch.save(net.state_dict(), path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
